package cnki.kg.algorithm.medium;

import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class ThreeSum {

    @Test
    public void test() {
        //int[] nums=new int[]{-1,0,1,2,-1,-4};
        int[] nums = new int[]{-2, 0, 0, 2, 2};
        List<List<Integer>> lists = threeSum(nums);
        System.out.println(lists);
    }

    public List<List<Integer>> threeSum(int[] nums) {
        List<List<Integer>> result = new ArrayList<>();
        Arrays.sort(nums);
        // 找出a + b + c = 0
        // a = nums[i], b = nums[j], c = -(a + b)
        for (int i = 0; i < nums.length; i++) {
            if (nums[i] > 0) {
                return result;//排序后第一项就大于0那肯定没有和为0的三元组
            }
            //三元组元素a去重，已经参与的元素就跳过
            if (i > 0 && nums[i] == nums[i - 1]) {//因为是排序，所以后一个元素与前一个元素相等就是重复,那为啥不直接去除重复元素呢，-2,1,1加和也是0
                continue;
            }
            int left = i + 1;
            int right = nums.length - 1;
            while (right > left) {
                int sum = nums[i] + nums[left] + nums[right];
                if (sum > 0) {
                    right--;
                } else if (sum < 0) {
                    left++;
                } else {
                    result.add(Arrays.asList(nums[i], nums[left], nums[right]));
                    while (right > left && nums[right] == nums[right - 1]) {//右指针的当前位置的元素和他前面的元素一样，则把右指针往左移动
                        right--;
                    }
                    while (right > left && nums[left] == nums[left + 1]) {//同理左指针当前元素和后一位一样，则左指针往后移动
                        left++;//左指针前进并去重
                    }
                    //继续下一次指针移动
                    right--;
                    left++;
                }
            }
        }
        return result;
    }

    public List<List<Integer>> threeSum2(int[] nums) {
        List<List<Integer>> res = new ArrayList<>();
        if (nums == null || nums.length < 3) return res;
        Arrays.sort(nums);
        for (int i = 0; i < nums.length; i++) {
            if (nums[i] > 0) {//每次遍历的第一个元素就大于0,后边肯定没有符合的三元组
                return res;
            }
            if(i>0&&nums[i]==nums[i-1]){
                continue;
            }
            int left = i + 1;
            int right = nums.length - 1;
            int sum = 0;
            while (left < right) {
                sum = nums[i] + nums[left] + nums[right];
                if (sum > 0) {
                    right--;
                } else if (sum < 0) {
                    left++;
                } else {
                    List<Integer> ints = new ArrayList<>();
                    ints.add(nums[i]);
                    ints.add(nums[left]);
                    ints.add(nums[right]);
                    res.add(ints);
                    while (right > left && nums[left] == nums[left + 1]) {
                        left++;
                    }
                    while (right > left && nums[right] == nums[right - 1]) {
                        right--;
                    }

                    left++;
                    right--;
                }
            }
        }
        return res;
    }
}
